import bisect
import json
import os
import time
from datetime import datetime
from datetime import timedelta

import click
import pandas as pd
import tushare as ts

from backtrader.utils.cmd import cli
from backtrader.utils import logger


class TradingCalendar:
    """Sorted list of open trading days for one exchange, loaded from tushare ``trade_cal``.

    Dates are kept as the ``cal_date`` values returned by the API, sorted ascending.
    """

    def __init__(self, api, exchange, start_date, end_date):
        """Fetch the open trading days of ``exchange`` between ``start_date`` and ``end_date``.

        :param api: tushare pro client; only its ``trade_cal`` method is used.
        :param exchange: exchange code passed through to ``trade_cal``.
        :param start_date: first calendar date to request.
        :param end_date: last calendar date to request.
        """
        self.exchange = exchange
        df = api.trade_cal(exchange=exchange, start_date=start_date, end_date=end_date)
        # Compare numerically so both 1 and "1" are treated as an open day.
        df = df[df["is_open"].astype(int) == 1]
        # Sort so first/last/next do not depend on the row order the API returns.
        # Assumes cal_date values are fixed-width YYYYMMDD strings, which sort
        # chronologically (the CLI requests the calendar with such bounds).
        self.dates = sorted(df["cal_date"].tolist())
        self.timeframe = "1d"

    def first(self) -> str:
        """Return the earliest trading date (IndexError if the calendar is empty)."""
        return self.dates[0]

    def last(self) -> str:
        """Return the latest trading date (IndexError if the calendar is empty)."""
        return self.dates[-1]

    def next(self, dt):
        """Return the first trading date strictly after ``dt``, or ``None`` if there is none.

        ``dt`` does not have to be a trading day of this calendar, so a date read
        back from an existing CSV is accepted even if it is outside the calendar.
        """
        idx = bisect.bisect_right(self.dates, dt)
        return self.dates[idx] if idx < len(self.dates) else None

    def next_skip(self, dt, skip=2000-1):
        """Return the trading date ``skip`` positions after ``dt``, clamped to the last date.

        With the default ``skip`` the inclusive window ``[dt, result]`` covers at
        most 2000 trading days (presumably matching the 2000-row per-request
        limit used by the downloader). Raises ValueError if ``dt`` is not a
        trading day of this calendar.
        """
        idx = min(self.dates.index(dt) + skip, len(self.dates) - 1)
        return self.dates[idx]


class FutureDownloaded:
    """Download daily bars of dominant ("主力") futures contracts of one exchange into CSV files.

    The workspace root is the ``ws`` entry of ``~/.jupyter/config.json``. Files written::

        <ws>/data/<exchange>/<exchange>_market.csv     contract list from fut_basic
        <ws>/data/<exchange>/<timeframe>/<name>.csv    bars per contract, ascending trade_date
    """

    # Seconds to sleep before every fut_daily request; presumably keeps the
    # downloader under the tushare rate limit (confirm against the account tier).
    request_interval = 3

    def __init__(self, api, tc: "TradingCalendar"):
        """Read the workspace config and prepare the output directory.

        :param api: tushare pro client (``fut_basic`` / ``fut_daily`` are used).
        :param tc: trading calendar of the exchange to download.
        """
        config_path = os.path.join(os.path.expanduser('~'), ".jupyter", "config.json")
        # Use a context manager so the config file handle is closed promptly.
        with open(config_path, 'r', encoding="utf-8") as fh:
            cfg = json.load(fh)
        self.ws_dir = cfg["ws"]
        self.exchange = tc.exchange
        self.future_dir = os.path.join(self.ws_dir, f"data/{self.exchange}")
        self.logger = logger.getLogger(__name__)

        os.makedirs(self.future_dir, exist_ok=True)
        # Trading days requested per fut_daily call; presumably the API's row cap per request.
        self.limit = 2000
        self.tc = tc
        self.api = api

    def update_market(self):
        """Fetch the exchange's contract list, overwrite ``<exchange>_market.csv`` with it, and return it."""
        market_path = os.path.join(self.future_dir, f"{self.exchange}_market.csv")
        df = self.api.fut_basic(exchange=self.exchange)
        df.to_csv(market_path, index=False)
        return df

    @classmethod
    def filter_by_condition(cls, df):
        """Select the dominant-contract rows, i.e. those whose ``name`` contains "主力".

        :param df: contract list with at least ``ts_code`` and ``name`` columns.
        :return: ``(ts_codes, names)`` — two parallel lists in the row order of ``df``.
        """
        # astype(str) keeps the .str accessor valid even for missing/non-string names.
        mask = df["name"].astype(str).str.contains("主力", regex=False)
        selected = df[mask]
        return selected["ts_code"].tolist(), selected["name"].tolist()

    def update_k_line(self, symbol, name):
        """Append missing bars for ``symbol`` to ``<future_dir>/<timeframe>/<name>.csv``.

        Resumes from the trading day after the last ``trade_date`` already in the
        file (or from the calendar's first day when there is none) and fetches in
        windows of at most ``self.limit`` trading days.
        """
        f_dir = os.path.join(self.future_dir, f'{self.tc.timeframe}')
        os.makedirs(f_dir, exist_ok=True)
        f_path = os.path.join(f_dir, f"{name}.csv")

        last_date = None
        if os.path.exists(f_path):
            existing = pd.read_csv(f_path, dtype={"trade_date": str})
            if not existing.empty:
                # Chunks are written sorted ascending, so the last row is the newest bar.
                last_date = existing.iloc[-1]["trade_date"]

        s_date = self.tc.first() if last_date is None else self.tc.next(dt=last_date)
        while s_date:
            time.sleep(self.request_interval)
            e_date = self.tc.next_skip(s_date, skip=self.limit - 1)
            self.logger.info("fetch [%s][%s]-[%s / %s]", name, symbol, s_date, e_date)
            df1 = self.api.fut_daily(ts_code=symbol, start_date=s_date, end_date=e_date)

            # An empty result may come back without columns, so skip it before sorting.
            if df1 is not None and not df1.empty:
                df1 = df1.sort_values("trade_date", ascending=True)
                # Write the header only when this call creates the file.
                new_file = not os.path.exists(f_path)
                df1.to_csv(f_path, index=False, mode="w" if new_file else "a", header=new_file)
            s_date = self.tc.next(e_date)

    def download(self):
        """Refresh the contract list, then update the bars of every dominant contract."""
        df = self.update_market()
        # Only dominant ("主力") contracts are downloaded.
        symbols, names = FutureDownloaded.filter_by_condition(df)
        total = len(symbols)
        for idx, (symbol, name) in enumerate(zip(symbols, names), start=1):
            self.logger.info("Download [%s], [%s], [%s][%s] - [%s/%s]", self.exchange, self.tc.timeframe, name, symbol, idx, total)
            self.update_k_line(symbol=symbol, name=name)

@cli.command()
@click.option("--exchange", default="default")
@click.option("--init_date", default="20100101")
@click.option("--token", envvar="TUSHARE_TOKEN", default=None,
              help="tushare pro API token; also read from the TUSHARE_TOKEN environment variable.")
def download(exchange, init_date, token=None):
    """Download daily bars of dominant futures contracts for the given exchanges.

    --exchange is a comma-separated list of exchange codes, or "default" for all
    of SHFE, DCE, CFFEX, CZCE and INE. --init_date is the first calendar date
    (YYYYMMDD) of the trading calendar.
    """
    # SHFE: Shanghai Futures Exchange; DCE: Dalian Commodity Exchange;
    # CFFEX: China Financial Futures Exchange; CZCE: Zhengzhou Commodity Exchange;
    # INE: Shanghai International Energy Exchange.
    if exchange == "default":
        exchanges = ["SHFE", "DCE", "CFFEX", "CZCE", "INE"]
    else:
        # Tolerate spaces after commas and empty items such as a trailing comma.
        exchanges = [ex.strip() for ex in exchange.split(",") if ex.strip()]

    # SECURITY(review): this credential is committed in source. It is kept only as a
    # backward-compatible fallback; rotate it and rely on --token / TUSHARE_TOKEN.
    api = ts.pro_api(token or "475827415ea4e675902fdef797462578441aa3114c74c37e68792e0b")
    format_str = "%Y%m%d"
    today = datetime.today()

    # Before 19:00 local time the calendar stops at yesterday; presumably today's
    # bars are not yet published at that point.
    end_day = today if today.hour > 18 else today - timedelta(days=1)
    e_date = end_day.strftime(format_str)

    for ex in exchanges:
        tc = TradingCalendar(api=api, exchange=ex, start_date=init_date, end_date=e_date)
        # Named `downloader` so it does not shadow this command function.
        downloader = FutureDownloaded(api=api, tc=tc)
        downloader.download()

@click.group()
def btfuture():
    # Root click group of this script; the `download` command is attached below.
    ...


btfuture.add_command(download)

if __name__ == "__main__":
    # Dispatch to the click group when run as a script.
    btfuture()
